from .Matrix import Matrix
from .Vector import Vector
from ._gloabl import is_zero

class LinearSystem:
    """线性系统"""
    def __init__(self,A,b):
        """
        构造一个增广矩阵
        :param A: 矩阵
        :param b: 向量
        """
        assert A.row_num() == len(b),"矩阵的行数和向量的个数相等，线性系统一一对应"
        self._m = A.row_num()
        self._n = A.col_num()
        #增广矩阵
        self.Ab = [Vector(A.row_vector(i).underlying_list() + [b[i]]) for i in range(self._m)]

        #存储主元的列表
        self.pivots = []

    def gauss_jordan_elimination(self):
        """高斯约旦消元法"""
        #前向消元
        self._forward()
        #后向消元
        self._backward()

        for i in range(len(self.pivots),self._m):
            if not is_zero(self.Ab[i][-1]):
                return False
        return True

    def _forward(self):
        """前向消元"""
        i,k = 0,0
        while i < self._m and k < self._n:
            """看self.Ab[i][k]是否为主元"""
            max_row = self._max_row(i,k,self._m)
            self.Ab[i],self.Ab[max_row] = self.Ab[max_row],self.Ab[i]
            #主元进行归一
            if is_zero(self.Ab[i][k]):
                k += 1
            else:
                self.Ab[i] = self.Ab[i] / self.Ab[i][k]
                #相减
                for j in range(i+1,self._m):
                    self.Ab[j] = self.Ab[j] - self.Ab[j][k] * self.Ab[i]
                self.pivots.append(k)
                i += 1

    def _backward(self):
        """后向消元"""
        n = len(self.pivots)
        for i in range(n-1,-1,-1):
            k = self.pivots[i]
            for j in range(i-1,-1,-1):
                self.Ab[j] = self.Ab[j] - self.Ab[j][k] * self.Ab[i]

    def _max_row(self,index_i,index_j,n):
        """
        找到 index行 到 n行之间最大主元的那个数及所在行
        """
        best,ret = self.Ab[index_i][index_j],index_i
        for i in range(index_i+1,n):
            if self.Ab[i][index_j] > best:
                best,ret = self.Ab[i][index_i],index_j

        return ret

    def fancy_print(self):
        """打印增广矩阵"""
        for i in range(self._m):
            print(" ".join(str(self.Ab[i][j]) for j in range(self._n)),end=" ")
            print("| ",self.Ab[i][-1])